{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OmniSafe Tutorial - Environment Customization from Community\n",
    "\n",
    "OmniSafe: https://github.com/PKU-Alignment/omnisafe\n",
    "\n",
    "Documentation: https://omnisafe.readthedocs.io/en/latest/\n",
    "\n",
    "Gymnasium: https://github.com/Farama-Foundation/Gymnasium\n",
    "\n",
    "[Gymnasium](https://github.com/Farama-Foundation/Gymnasium) is an open source Python library for developing and comparing reinforcement learning algorithms by providing a standard API to communicate between learning algorithms and environments, as well as a standard set of environments compliant with that API.\n",
    "\n",
    "## Introduction\n",
    "\n",
    "In this section, we show how to embed an existing community environment into OmniSafe. The tasks provided by [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) are widely used in reinforcement learning. Specifically, we use [Pendulum-v1](https://gymnasium.farama.org/environments/classic_control/pendulum/) as an example to show how to embed a Gymnasium task into OmniSafe.\n",
    "\n",
    "## Quick Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install via pip (skip this cell if OmniSafe is already installed)\n",
    "!pip install omnisafe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install from source (skip this cell if OmniSafe is already installed)\n",
    "## Clone the repository\n",
    "!git clone https://github.com/PKU-Alignment/omnisafe\n",
    "%cd omnisafe\n",
    "\n",
    "## Install the package\n",
    "!pip install -e ."
   ]
  },
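  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Embedding a Gymnasium Task\n",
    "The core of embedding an environment is providing enough static and dynamic information for the SafeRL agent to interact and train. This section details the variables that must be defined when embedding an environment and the conventions they follow. We first walk through the whole embedding process in the order in which the code is written, to give you a first overview. We then review all of the code and summarize the adaptations you need to make when customizing your own environment.\n",
    "\n",
    "\n",
    "### Quick Start\n",
    "First, import all the external dependencies needed for this tutorial."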
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the required packages\n",
    "from __future__ import annotations\n",
    "\n",
    "from typing import Any, ClassVar\n",
    "import gymnasium\n",
    "import torch\n",
    "import numpy as np\n",
    "import omnisafe\n",
    "\n",
    "from omnisafe.envs.core import CMDP, env_register, env_unregister\n",
    "from omnisafe.typing import DEVICE_CPU"
   ]
  },
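  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, create a class named `ExampleMuJoCoEnv` that inherits from `CMDP`. (We do this because we want to convert the environment's interaction into the CMDP paradigm; you can define a new abstract class to implement a different paradigm if needed.)"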
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExampleMuJoCoEnv(CMDP):\n",
    "    _support_envs: ClassVar[list[str]] = ['Pendulum-v1']  # Supported task names\n",
    "\n",
    "    need_auto_reset_wrapper = True  # Whether the `AutoReset` wrapper is needed\n",
    "    need_time_limit_wrapper = True  # Whether the `TimeLimit` wrapper is needed\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        env_id: str,\n",
    "        num_envs: int = 1,\n",
    "        device: torch.device = DEVICE_CPU,\n",
    "        **kwargs: Any,\n",
    "    ) -> None:\n",
    "        super().__init__(env_id)\n",
    "        self._num_envs = num_envs\n",
    "        self._env = gymnasium.make(id=env_id, autoreset=True, **kwargs)  # Instantiate the environment\n",
    "        self._action_space = self._env.action_space  # Action space, read by the algorithm layer at initialization\n",
    "        self._observation_space = self._env.observation_space  # Observation space, read by the algorithm layer at initialization\n",
    "        self._device = device  # Optional: use a GPU for acceleration. Defaults to CPU\n",
    "\n",
    "    def reset(\n",
    "        self,\n",
    "        seed: int | None = None,\n",
    "        options: dict[str, Any] | None = None,\n",
    "    ) -> tuple[torch.Tensor, dict[str, Any]]:\n",
    "        obs, info = self._env.reset(seed=seed, options=options)  # Reset the environment\n",
    "        return (\n",
    "            torch.as_tensor(obs, dtype=torch.float32, device=self._device),\n",
    "            info,\n",
    "        )  # Convert the observation after reset to a torch tensor.\n",
    "\n",
    "    @property\n",
    "    def max_episode_steps(self) -> int | None:\n",
    "        return self._env.env.spec.max_episode_steps  # Maximum number of interaction steps per episode\n",
    "\n",
    "    def set_seed(self, seed: int) -> None:\n",
    "        self.reset(seed=seed)  # Seed the environment for reproducibility\n",
    "\n",
    "    def render(self) -> Any:\n",
    "        return self._env.render()  # Return the rendered image of the environment\n",
    "\n",
    "    def close(self) -> None:\n",
    "        self._env.close()  # Release the environment instance after training\n",
    "\n",
    "    def step(\n",
    "        self,\n",
    "        action: torch.Tensor,\n",
    "    ) -> tuple[\n",
    "        torch.Tensor,\n",
    "        torch.Tensor,\n",
    "        torch.Tensor,\n",
    "        torch.Tensor,\n",
    "        torch.Tensor,\n",
    "        dict[str, Any],\n",
    "    ]:\n",
    "        obs, reward, terminated, truncated, info = self._env.step(\n",
    "            action.detach().cpu().numpy(),\n",
    "        )  # Read the dynamic information returned by the environment step\n",
    "        cost = np.zeros_like(reward)  # Gymnasium does not explicitly include safety constraints; this is only a placeholder.\n",
    "        obs, reward, cost, terminated, truncated = (\n",
    "            torch.as_tensor(x, dtype=torch.float32, device=self._device)\n",
    "            for x in (obs, reward, cost, terminated, truncated)\n",
    "        )  # Convert the dynamic information to torch tensors.\n",
    "        if 'final_observation' in info:\n",
    "            info['final_observation'] = np.array(\n",
    "                [\n",
    "                    array if array is not None else np.zeros(obs.shape[-1])\n",
    "                    for array in info['final_observation']\n",
    "                ],\n",
    "            )\n",
    "            info['final_observation'] = torch.as_tensor(\n",
    "                info['final_observation'],\n",
    "                dtype=torch.float32,\n",
    "                device=self._device,\n",
    "            )  # Convert the previous episode's final observation recorded in info to a torch tensor.\n",
    "\n",
    "        return obs, reward, cost, terminated, truncated, info"
   ]
  },
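  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comments in the code above explain what each part does. For a more detailed explanation, see [Tutorial 3: Environment Customization from Scratch](./3.Environment%20Customization.ipynb). The key points are summarized below:\n",
    "\n",
    "- **Static variables OmniSafe needs at initialization**\n",
    "\n",
    "| Static information | Required | Definition | Type | Example |\n",
    "|:---:|:---:|:---:|:---:|:---:|\n",
    "| `need_auto_reset_wrapper` | Yes | Whether the `AutoReset` wrapper is needed | `bool` | `True` |\n",
    "| `need_time_limit_wrapper` | Yes | Whether the `TimeLimit` wrapper is needed | `bool` | `True` |\n",
    "| `_action_space` | Yes | Action space | `gymnasium.spaces.Box` | `Box(low=-1.0, high=1.0, shape=(2,))` |\n",
    "| `_observation_space` | Yes | Observation space | `gymnasium.spaces.Box` | `Box(low=-1.0, high=1.0, shape=(3,))` |\n",
    "| `max_episode_steps` | Yes | Maximum number of interaction steps per episode | A function decorated with `@property` that returns an `int` or `None` | See the code block above |\n",
    "| `_num_envs` | No | Number of parallel environments | `int` | `5` |\n",
    "| `_device` | No | torch compute device | `torch.device` | `DEVICE_CPU` |\n",
    "\n",
    "- **Dynamic variables the environment must provide to OmniSafe**\n",
    "\n",
    "OmniSafe agents interact with the environment mainly through the `reset` and `step` functions. Make sure that the type, number, and order of the return values of your customized environment match the example above. Specifically:\n",
    "\n",
    "| Dynamic information | Type | Count | Order |\n",
    "|:---:|:---:|:---:|:---:|\n",
    "| `step` | `tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]` | 6 | `obs`, `reward`, `cost`, `terminated`, `truncated`, `info` |\n",
    "| `reset` | `tuple[torch.Tensor, dict[str, Any]]` | 2 | `obs`, `info` |\n",
    "\n",
    "- **Notes**\n",
    "\n",
    "1. Although `_num_envs` and `_device` are not required, please keep both parameters in the signature of `__init__`.\n",
    "2. `_num_envs` is an advanced parameter for sampling from multiple environments in parallel; it is the number of environment instances to create. If your customized environment also supports a configurable number of parallel instances, expose it through `_num_envs` rather than defining a new interface."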
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, register the environment with OmniSafe using the `@env_register` decorator. It can then be used for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ExampleMuJoCoEnv has not been registered yet\n",
      "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Logging data to .</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/runs/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">PPOLag-</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">{Pendulum-v1}</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/seed-000-2024-04-09-15-05-55/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">progress.csv</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-05-55/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Save with config in config.json</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;33mSave with config in config.json\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008000; text-decoration-color: #008000\">INFO: Start training</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[32mINFO: Start training\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n",
       "for Jupyter support\n",
       "  warnings.warn('install \"ipywidgets\" for Jupyter support')\n",
       "</pre>\n"
      ],
      "text/plain": [
       "/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n",
       "for Jupyter support\n",
       "  warnings.warn('install \"ipywidgets\" for Jupyter support')\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008000; text-decoration-color: #008000\">Warning: trajectory cut off when rollout by epoch at </span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">200.0</span><span style=\"color: #008000; text-decoration-color: #008000\"> steps.</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m200.0\u001b[0m\u001b[32m steps.\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\"> Metrics                        </span>┃<span style=\"font-weight: bold\"> Value                 </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ Metrics/EpRet                  │ -1616.242431640625    │\n",
       "│ Metrics/EpCost                 │ 0.0                   │\n",
       "│ Metrics/EpLen                  │ 200.0                 │\n",
       "│ Train/Epoch                    │ 0.0                   │\n",
       "│ Train/Entropy                  │ 1.4185898303985596    │\n",
       "│ Train/KL                       │ 0.0007516025798395276 │\n",
       "│ Train/StopIter                 │ 1.0                   │\n",
       "│ Train/PolicyRatio/Mean         │ 0.9966228604316711    │\n",
       "│ Train/PolicyRatio/Min          │ 0.9966228604316711    │\n",
       "│ Train/PolicyRatio/Max          │ 0.9966228604316711    │\n",
       "│ Train/PolicyRatio/Std          │ 0.0075334208086133    │\n",
       "│ Train/LR                       │ 0.0                   │\n",
       "│ Train/PolicyStd                │ 0.9996514320373535    │\n",
       "│ TotalEnvSteps                  │ 200.0                 │\n",
       "│ Loss/Loss_pi                   │ 0.08751548826694489   │\n",
       "│ Loss/Loss_pi/Delta             │ 0.08751548826694489   │\n",
       "│ Value/Adv                      │ -0.398242324590683    │\n",
       "│ Loss/Loss_reward_critic        │ 16605.1796875         │\n",
       "│ Loss/Loss_reward_critic/Delta  │ 16605.1796875         │\n",
       "│ Value/reward                   │ 0.0049050007946789265 │\n",
       "│ Loss/Loss_cost_critic          │ 0.052194785326719284  │\n",
       "│ Loss/Loss_cost_critic/Delta    │ 0.052194785326719284  │\n",
       "│ Value/cost                     │ 0.07966174930334091   │\n",
       "│ Time/Total                     │ 0.2075355052947998    │\n",
       "│ Time/Rollout                   │ 0.1734788417816162    │\n",
       "│ Time/Update                    │ 0.033020973205566406  │\n",
       "│ Time/Epoch                     │ 0.20653653144836426   │\n",
       "│ Time/FPS                       │ 968.3539428710938     │\n",
       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Min │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Max │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Std │ 0.0                   │\n",
       "└────────────────────────────────┴───────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1mMetrics                       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue                \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ Metrics/EpRet                  │ -1616.242431640625    │\n",
       "│ Metrics/EpCost                 │ 0.0                   │\n",
       "│ Metrics/EpLen                  │ 200.0                 │\n",
       "│ Train/Epoch                    │ 0.0                   │\n",
       "│ Train/Entropy                  │ 1.4185898303985596    │\n",
       "│ Train/KL                       │ 0.0007516025798395276 │\n",
       "│ Train/StopIter                 │ 1.0                   │\n",
       "│ Train/PolicyRatio/Mean         │ 0.9966228604316711    │\n",
       "│ Train/PolicyRatio/Min          │ 0.9966228604316711    │\n",
       "│ Train/PolicyRatio/Max          │ 0.9966228604316711    │\n",
       "│ Train/PolicyRatio/Std          │ 0.0075334208086133    │\n",
       "│ Train/LR                       │ 0.0                   │\n",
       "│ Train/PolicyStd                │ 0.9996514320373535    │\n",
       "│ TotalEnvSteps                  │ 200.0                 │\n",
       "│ Loss/Loss_pi                   │ 0.08751548826694489   │\n",
       "│ Loss/Loss_pi/Delta             │ 0.08751548826694489   │\n",
       "│ Value/Adv                      │ -0.398242324590683    │\n",
       "│ Loss/Loss_reward_critic        │ 16605.1796875         │\n",
       "│ Loss/Loss_reward_critic/Delta  │ 16605.1796875         │\n",
       "│ Value/reward                   │ 0.0049050007946789265 │\n",
       "│ Loss/Loss_cost_critic          │ 0.052194785326719284  │\n",
       "│ Loss/Loss_cost_critic/Delta    │ 0.052194785326719284  │\n",
       "│ Value/cost                     │ 0.07966174930334091   │\n",
       "│ Time/Total                     │ 0.2075355052947998    │\n",
       "│ Time/Rollout                   │ 0.1734788417816162    │\n",
       "│ Time/Update                    │ 0.033020973205566406  │\n",
       "│ Time/Epoch                     │ 0.20653653144836426   │\n",
       "│ Time/FPS                       │ 968.3539428710938     │\n",
       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Min │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Max │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Std │ 0.0                   │\n",
       "└────────────────────────────────┴───────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "(-1616.242431640625, 0.0, 200.0)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@env_register\n",
    "@env_unregister  # Avoids the \"environment has already been registered\" error when the cell is re-run\n",
    "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n",
    "    pass\n",
    "\n",
    "\n",
    "custom_cfgs = {\n",
    "    'train_cfgs': {\n",
    "        'total_steps': 200,\n",
    "    },\n",
    "    'algo_cfgs': {\n",
    "        'steps_per_epoch': 200,\n",
    "        'update_iters': 1,\n",
    "    },\n",
    "}\n",
    "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n",
    "agent.learn()"
   ]
  },
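  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Advanced Usage\n",
    "In addition to the usage above, community environments can also take advantage of OmniSafe's environment-specific parameter passing and information logging. The following sections show how."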
   ]
  },
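  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Environment-Specific Parameters\n",
    "\n",
    "Take `Pendulum-v1` as an example. According to the official Gymnasium documentation, a task-specific parameter `g`, the gravitational acceleration, can be specified when the task is created. Let us first look at its default value:"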
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Logging data to .</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/runs/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">PPOLag-</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">{Pendulum-v1}</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/seed-000-2024-04-09-15-05-58/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">progress.csv</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-05-58/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Save with config in config.json</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;33mSave with config in config.json\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "10.0"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@env_register\n",
    "@env_unregister  # Avoids the \"environment has already been registered\" error when the cell is re-run\n",
    "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n",
    "    def __getattr__(self, name: str) -> Any:\n",
    "        \"\"\"Get the attribute of the environment.\"\"\"\n",
    "        if name.startswith('_'):\n",
    "            raise AttributeError(f'attempted to get missing private attribute {name}')\n",
    "        return getattr(self._env, name)\n",
    "\n",
    "\n",
    "custom_cfgs = {\n",
    "    'train_cfgs': {\n",
    "        'total_steps': 200,\n",
    "    },\n",
    "    'algo_cfgs': {\n",
    "        'steps_per_epoch': 200,\n",
    "        'update_iters': 1,\n",
    "    },\n",
    "}\n",
    "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n",
    "agent.agent._env._env.g"
   ]
  },
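  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We implemented a magic method named `__getattr__`, which forwards lookups of public attributes to the wrapped environment so that task-specific parameters of the instantiated environment can be inspected. In this example, the default value of the gravitational acceleration `g` is 10.0.\n",
    "\n",
    "According to the Gymnasium documentation, this parameter can be specified when the environment is created with `gymnasium.make`. Does OmniSafe support passing such environment-specific parameters to customized environments? Yes, and it is simple to do:"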
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Logging data to .</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/runs/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">PPOLag-</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">{Pendulum-v1}</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/seed-000-2024-04-09-15-06-01/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">progress.csv</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-06-01/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Save with config in config.json</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;33mSave with config in config.json\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "9.8"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "custom_cfgs.update({'env_cfgs': {'g': 9.8}})\n",
    "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n",
    "agent.agent._env._env.g"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The gravitational acceleration is now 9.8. To pass environment-specific parameters, we only need to set the corresponding keys and values in `env_cfgs`."
   ]
  },
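  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Information Logging\n",
    "\n",
    "The `Pendulum-v1` task has many task-specific dynamic quantities. We now show how to record them with OmniSafe's `Logger`. As an example, we log the per-episode maximum and cumulative value of the angular velocity `angular_velocity`."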
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Logging data to .</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/runs/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">PPOLag-</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">{Pendulum-v1}</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/seed-000-2024-04-09-15-06-03/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">progress.csv</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-06-03/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Save with config in config.json</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;33mSave with config in config.json\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008000; text-decoration-color: #008000\">INFO: Start training</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[32mINFO: Start training\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008000; text-decoration-color: #008000\">Warning: trajectory cut off when rollout by epoch at </span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">200.0</span><span style=\"color: #008000; text-decoration-color: #008000\"> steps.</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m200.0\u001b[0m\u001b[32m steps.\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\"> Metrics                         </span>┃<span style=\"font-weight: bold\"> Value                 </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ Metrics/EpRet                   │ -1607.6717529296875   │\n",
       "│ Metrics/EpCost                  │ 0.0                   │\n",
       "│ Metrics/EpLen                   │ 200.0                 │\n",
       "│ Train/Epoch                     │ 0.0                   │\n",
       "│ Train/Entropy                   │ 1.418560266494751     │\n",
       "│ Train/KL                        │ 0.0005777678452432156 │\n",
       "│ Train/StopIter                  │ 1.0                   │\n",
       "│ Train/PolicyRatio/Mean          │ 0.9981198310852051    │\n",
       "│ Train/PolicyRatio/Min           │ 0.9981198310852051    │\n",
       "│ Train/PolicyRatio/Max           │ 0.9981198310852051    │\n",
       "│ Train/PolicyRatio/Std           │ 0.005412393249571323  │\n",
       "│ Train/LR                        │ 0.0                   │\n",
       "│ Train/PolicyStd                 │ 0.9996219277381897    │\n",
       "│ TotalEnvSteps                   │ 200.0                 │\n",
       "│ Loss/Loss_pi                    │ 0.09192709624767303   │\n",
       "│ Loss/Loss_pi/Delta              │ 0.09192709624767303   │\n",
       "│ Value/Adv                       │ -0.4177907109260559   │\n",
       "│ Loss/Loss_reward_critic         │ 16393.2265625         │\n",
       "│ Loss/Loss_reward_critic/Delta   │ 16393.2265625         │\n",
       "│ Value/reward                    │ 0.00719139538705349   │\n",
       "│ Loss/Loss_cost_critic           │ 0.05219484493136406   │\n",
       "│ Loss/Loss_cost_critic/Delta     │ 0.05219484493136406   │\n",
       "│ Value/cost                      │ 0.07949987053871155   │\n",
       "│ Time/Total                      │ 0.2163846492767334    │\n",
       "│ Time/Rollout                    │ 0.18010711669921875   │\n",
       "│ Time/Update                     │ 0.03433847427368164   │\n",
       "│ Time/Epoch                      │ 0.21448636054992676   │\n",
       "│ Time/FPS                        │ 932.4664306640625     │\n",
       "│ Env/Max_angular_velocity        │ 2.9994523525238037    │\n",
       "│ Env/Cumulative_angular_velocity │ 1.0643725395202637    │\n",
       "│ Metrics/LagrangeMultiplier/Mean │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Min  │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Max  │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Std  │ 0.0                   │\n",
       "└─────────────────────────────────┴───────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1mMetrics                        \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue                \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ Metrics/EpRet                   │ -1607.6717529296875   │\n",
       "│ Metrics/EpCost                  │ 0.0                   │\n",
       "│ Metrics/EpLen                   │ 200.0                 │\n",
       "│ Train/Epoch                     │ 0.0                   │\n",
       "│ Train/Entropy                   │ 1.418560266494751     │\n",
       "│ Train/KL                        │ 0.0005777678452432156 │\n",
       "│ Train/StopIter                  │ 1.0                   │\n",
       "│ Train/PolicyRatio/Mean          │ 0.9981198310852051    │\n",
       "│ Train/PolicyRatio/Min           │ 0.9981198310852051    │\n",
       "│ Train/PolicyRatio/Max           │ 0.9981198310852051    │\n",
       "│ Train/PolicyRatio/Std           │ 0.005412393249571323  │\n",
       "│ Train/LR                        │ 0.0                   │\n",
       "│ Train/PolicyStd                 │ 0.9996219277381897    │\n",
       "│ TotalEnvSteps                   │ 200.0                 │\n",
       "│ Loss/Loss_pi                    │ 0.09192709624767303   │\n",
       "│ Loss/Loss_pi/Delta              │ 0.09192709624767303   │\n",
       "│ Value/Adv                       │ -0.4177907109260559   │\n",
       "│ Loss/Loss_reward_critic         │ 16393.2265625         │\n",
       "│ Loss/Loss_reward_critic/Delta   │ 16393.2265625         │\n",
       "│ Value/reward                    │ 0.00719139538705349   │\n",
       "│ Loss/Loss_cost_critic           │ 0.05219484493136406   │\n",
       "│ Loss/Loss_cost_critic/Delta     │ 0.05219484493136406   │\n",
       "│ Value/cost                      │ 0.07949987053871155   │\n",
       "│ Time/Total                      │ 0.2163846492767334    │\n",
       "│ Time/Rollout                    │ 0.18010711669921875   │\n",
       "│ Time/Update                     │ 0.03433847427368164   │\n",
       "│ Time/Epoch                      │ 0.21448636054992676   │\n",
       "│ Time/FPS                        │ 932.4664306640625     │\n",
       "│ Env/Max_angular_velocity        │ 2.9994523525238037    │\n",
       "│ Env/Cumulative_angular_velocity │ 1.0643725395202637    │\n",
       "│ Metrics/LagrangeMultiplier/Mean │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Min  │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Max  │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Std  │ 0.0                   │\n",
       "└─────────────────────────────────┴───────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "(-1607.6717529296875, 0.0, 200.0)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from omnisafe.common.logger import Logger\n",
    "\n",
    "\n",
    "@env_register\n",
    "@env_unregister  # avoids an \"environment already registered\" error when this cell is re-run\n",
    "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n",
    "\n",
    "    def __init__(self, env_id, num_envs, device, **kwargs):\n",
    "        super().__init__(env_id, num_envs, device, **kwargs)\n",
    "        self.env_spec_log = {\n",
    "            'Env/Max_angular_velocity': 0.0,\n",
    "            'Env/Cumulative_angular_velocity': 0.0,\n",
    "        }  # re-declare and initialize the logged keys in the constructor\n",
    "\n",
    "    def spec_log(self, logger: Logger) -> None:\n",
    "        for key, value in self.env_spec_log.items():\n",
    "            logger.store({key: value})\n",
    "            self.env_spec_log[key] = 0.0\n",
    "\n",
    "    def step(self, action):\n",
    "        obs, reward, cost, terminated, truncated, info = super().step(action=action)\n",
    "        # The last observation entry of Pendulum-v1 is the angular velocity\n",
    "        angular_velocity = obs[-1].item()\n",
    "        self.env_spec_log['Env/Max_angular_velocity'] = max(\n",
    "            self.env_spec_log['Env/Max_angular_velocity'], angular_velocity\n",
    "        )\n",
    "        self.env_spec_log['Env/Cumulative_angular_velocity'] += angular_velocity\n",
    "        return obs, reward, cost, terminated, truncated, info\n",
    "\n",
    "\n",
    "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n",
    "agent.learn()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great! We have successfully recorded the environment-specific information we need in the `Logger`. Notably, we did not modify any OmniSafe source code in the process."
   ]
  },
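  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "In this section, we used Gymnasium's classic `Pendulum-v1` environment to show the interface adaptations and information an existing community environment must provide to be embedded in OmniSafe. We hope this tutorial helps with your own custom environment integration. If you would like your environment to become one of OmniSafe's officially supported environments, or if you run into difficulties while customizing an environment, feel free to reach out to us through [Issues](https://github.com/PKU-Alignment/omnisafe/issues), [Pull Requests](https://github.com/PKU-Alignment/omnisafe/pulls), or [Discussions](https://github.com/PKU-Alignment/omnisafe/discussions)."
   ]
  }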
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "omnisafe",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
